import re
import unittest
import unittest.mock
from io import StringIO

import pytest
import yaml
from pydantic import ValidationError

from datahub.ingestion.graph.filters import (
    RemovedStatusFilter,
    SearchFilterRule,
    generate_filter,
)
from datahub.metadata.urns import DataPlatformUrn, QueryUrn, Urn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_client import compile_filters, compute_entity_types
from datahub.sdk.search_filters import (
    Filter,
    FilterDsl as F,
    _BaseFilter,
    _CustomCondition,
    _filter_discriminator,
    load_filters,
)
from datahub.utilities.urns.error import InvalidUrnError
from tests.test_helpers.graph_helpers import MockDataHubGraph


def test_filters_simple() -> None:
    yaml_dict = {"platform": ["snowflake", "bigquery"]}
    filter_obj: Filter = load_filters(yaml_dict)
    assert filter_obj == F.platform(["snowflake", "bigquery"])
    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(
                    field="platform.keyword",
                    condition="EQUAL",
                    values=[
                        "urn:li:dataPlatform:snowflake",
                        "urn:li:dataPlatform:bigquery",
                    ],
                )
            ]
        }
    ]
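

# A minimal companion sketch: the same filter expressed as a YAML document,
# mirroring the yaml.safe_load usage in test_filters_complex below.
def test_filters_simple_yaml_doc() -> None:
    yaml_dict = yaml.safe_load(StringIO("platform: [snowflake, bigquery]\n"))
    assert load_filters(yaml_dict) == F.platform(["snowflake", "bigquery"])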


def test_filters_and() -> None:
    yaml_dict = {
        "and": [
            {"env": ["PROD"]},
            {"platform": ["snowflake", "bigquery"]},
        ]
    }
    filter_obj: Filter = load_filters(yaml_dict)
    assert filter_obj == F.and_(
        F.env("PROD"),
        F.platform(["snowflake", "bigquery"]),
    )
    platform_rule = SearchFilterRule(
        field="platform.keyword",
        condition="EQUAL",
        values=[
            "urn:li:dataPlatform:snowflake",
            "urn:li:dataPlatform:bigquery",
        ],
    )
    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
                platform_rule,
            ]
        },
        {
            "and": [
                SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
                platform_rule,
            ]
        },
    ]
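

# A hedged sketch extrapolating from the tests above: an OR of two
# single-clause filters should surface as two top-level OR clauses,
# with no cross-product involved.
def test_filters_or_clause_count() -> None:
    filter_obj = F.or_(
        F.platform("snowflake"),
        F.platform("bigquery"),
    )
    assert len(filter_obj.compile()) == 2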


def test_filters_complex() -> None:
    yaml_dict = yaml.safe_load(
        StringIO("""\
and:
  - env: [PROD]
  - or:
    - platform: [ snowflake, bigquery ]
    - and:
      - platform: [postgres]
      - not:
          domain: [urn:li:domain:analytics]
    - field: customProperties
      condition: EQUAL
      values: ["dbt_unique_id=source.project.name"]
""")
    )
    filter_obj: Filter = load_filters(yaml_dict)
    assert filter_obj == F.and_(
        F.env("PROD"),
        F.or_(
            F.platform(["snowflake", "bigquery"]),
            F.and_(
                F.platform("postgres"),
                F.not_(F.domain("urn:li:domain:analytics")),
            ),
            F.has_custom_property("dbt_unique_id", "source.project.name"),
        ),
    )
    warehouse_rule = SearchFilterRule(
        field="platform.keyword",
        condition="EQUAL",
        values=["urn:li:dataPlatform:snowflake", "urn:li:dataPlatform:bigquery"],
    )
    postgres_rule = SearchFilterRule(
        field="platform.keyword",
        condition="EQUAL",
        values=["urn:li:dataPlatform:postgres"],
    )
    domain_rule = SearchFilterRule(
        field="domains",
        condition="EQUAL",
        values=["urn:li:domain:analytics"],
        negated=True,
    )
    custom_property_rule = SearchFilterRule(
        field="customProperties",
        condition="EQUAL",
        values=["dbt_unique_id=source.project.name"],
    )

    # The original filter has one OR clause with 3 branches, and the env
    # filter expands into a hidden OR with 2 branches (the "origin" and "env"
    # field variants). The cross-product gives 3 * 2 = 6 top-level OR clauses.
    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
                warehouse_rule,
            ],
        },
        {
            "and": [
                SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
                postgres_rule,
                domain_rule,
            ],
        },
        {
            "and": [
                SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]),
                custom_property_rule,
            ],
        },
        {
            "and": [
                SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
                warehouse_rule,
            ],
        },
        {
            "and": [
                SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
                postgres_rule,
                domain_rule,
            ],
        },
        {
            "and": [
                SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]),
                custom_property_rule,
            ],
        },
    ]


def test_entity_subtype_filter() -> None:
    filter_obj_1: Filter = load_filters({"entity_subtype": ["Table"]})
    assert filter_obj_1 == F.entity_subtype("Table")

    # Ensure it works without the list wrapper to maintain backwards compatibility.
    filter_obj_2: Filter = load_filters({"entity_subtype": "Table"})
    assert filter_obj_1 == filter_obj_2


def test_filters_all_types() -> None:
    filter_obj: Filter = load_filters(
        {
            "and": [
                {
                    "or": [
                        {"entity_type": ["dataset"]},
                        {"entity_type": ["chart", "dashboard"]},
                    ]
                },
                {"not": {"entity_subtype": ["Table"]}},
                {"platform": ["snowflake"]},
                {"domain": ["urn:li:domain:marketing"]},
                {
                    "container": ["urn:li:container:f784c48c306ba1c775ef917e2f8c1560"],
                    "direct_descendants_only": True,
                },
                {"env": ["PROD"]},
                {"status": "NOT_SOFT_DELETED"},
                {"glossary_term": ["urn:li:glossaryTerm:data-quality"]},
                {"tag": ["urn:li:tag:data-quality"]},
                {
                    "field": "custom_field",
                    "condition": "GREATER_THAN_OR_EQUAL_TO",
                    "values": ["5"],
                },
            ]
        }
    )
    assert filter_obj == F.and_(
        F.or_(
            F.entity_type("dataset"),
            F.entity_type(["chart", "dashboard"]),
        ),
        F.not_(F.entity_subtype("Table")),
        F.platform("snowflake"),
        F.domain("urn:li:domain:marketing"),
        F.container(
            "urn:li:container:f784c48c306ba1c775ef917e2f8c1560",
            direct_descendants_only=True,
        ),
        F.env("PROD"),
        F.soft_deleted(RemovedStatusFilter.NOT_SOFT_DELETED),
        F.glossary_term("urn:li:glossaryTerm:data-quality"),
        F.tag("urn:li:tag:data-quality"),
        F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]),
    )


def test_field_discriminator() -> None:
    with pytest.raises(ValueError, match="Cannot get discriminator for _BaseFilter"):
        _BaseFilter._field_discriminator()

    assert F.entity_type("dataset")._field_discriminator() == "entity_type"
    assert F.not_(F.entity_subtype("Table"))._field_discriminator() == "not"
    assert (
        F.custom_filter(
            "custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"]
        )._field_discriminator()
        == _CustomCondition._field_discriminator()
    )

    class _BadFilter(_BaseFilter):
        field1: str
        field2: str

    with pytest.raises(
        ValueError,
        match=re.escape(
            "Found multiple fields that could be the discriminator for this filter: ['field1', 'field2']"
        ),
    ):
        _BadFilter._field_discriminator()


def test_filter_discriminator() -> None:
    # Simple filter discriminator extraction.
    assert _filter_discriminator(F.entity_type("dataset")) == "entity_type"
    assert _filter_discriminator({"entity_type": "dataset"}) == "entity_type"
    assert _filter_discriminator({"not": {"entity_subtype": "Table"}}) == "not"
    assert _filter_discriminator({"unknown_field": 6}) == "unknown_field"
    assert _filter_discriminator({"field1": 6, "field2": 7}) is None
    assert _filter_discriminator({}) is None
    assert _filter_discriminator(6) is None

    # Special cases.
    assert (
        _filter_discriminator(
            {
                "field": "custom_field",
                "condition": "GREATER_THAN_OR_EQUAL_TO",
                "values": ["5"],
            }
        )
        == "_custom"
    )
    assert (
        _filter_discriminator(
            {
                "field": "custom_field",
                "condition": "EXISTS",
            }
        )
        == "_custom"
    )
    assert (
        _filter_discriminator(
            {"container": ["urn:li:container:f784c48c306ba1c775ef917e2f8c1560"]}
        )
        == "container"
    )
    assert (
        _filter_discriminator(
            {
                "container": ["urn:li:container:f784c48c306ba1c775ef917e2f8c1560"],
                "direct_descendants_only": True,
            }
        )
        == "container"
    )
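

# A hedged sketch: a bare field/condition/values dict (discriminated as
# "_custom" above) should load to the same filter that F.custom_filter builds,
# matching its usage in test_filters_all_types.
def test_custom_condition_load() -> None:
    loaded: Filter = load_filters(
        {
            "field": "custom_field",
            "condition": "GREATER_THAN_OR_EQUAL_TO",
            "values": ["5"],
        }
    )
    assert loaded == F.custom_filter("custom_field", "GREATER_THAN_OR_EQUAL_TO", ["5"])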


def test_tagged_union_error_messages() -> None:
    # With pydantic v2, we get validation errors for each union member
    with pytest.raises(
        ValidationError,
        match=re.compile(
            r"validation error.*entity_type.*Input should be a valid list",
            re.DOTALL,
        ),
    ):
        load_filters({"entity_type": 6})

    # Without discriminators, we get verbose union errors for unknown fields
    with pytest.raises(
        ValidationError,
        match=re.compile(
            r"validation error.*unknown_field.*Extra inputs are not permitted",
            re.DOTALL,
        ),
    ):
        load_filters({"and": [{"unknown_field": 6}]})


def test_filter_before_validators() -> None:
    # Test that we can load a filter from a string.
    # Sometimes we get filters encoded as JSON, and we want to handle those gracefully.
    filter_str = '{\n  "and": [\n    {"entity_type": ["dataset"]},\n    {"entity_subtype": ["Table"]},\n    {"platform": ["snowflake"]}\n  ]\n}'
    assert load_filters(filter_str) == F.and_(
        F.entity_type("dataset"),
        F.entity_subtype("Table"),
        F.platform("snowflake"),
    )
    with pytest.raises(
        ValidationError,
        match=re.compile(
            r"validation error.*Input should be a valid dictionary", re.DOTALL
        ),
    ):
        load_filters("this is invalid json but should not raise a json error")

    # Test that we can load a filter from an "and"-like dictionary.
    # Sometimes we get filters that are not wrapped in an "and" clause.
    filter_str = '{"entity_type": ["dataset"], "entity_subtype": ["Table"], "platform": ["snowflake"]}'
    assert load_filters(filter_str) == F.and_(
        F.entity_type("dataset"),
        F.entity_subtype("Table"),
        F.platform("snowflake"),
    )

    filter_str = '{"entity_type": ["dataset"], "container": ["urn:li:container:f784c48c306ba1c775ef917e2f8c1560"]}'
    assert load_filters(filter_str) == F.and_(
        F.entity_type("dataset"),
        F.container("urn:li:container:f784c48c306ba1c775ef917e2f8c1560"),
    )

    filter_str = '{"entity_type": ["dataset"], "container": ["urn:li:container:f784c48c306ba1c775ef917e2f8c1560"], "direct_descendants_only": true}'
    with pytest.raises(
        ValidationError,
        match=re.compile(
            r"validation error.*Extra inputs are not permitted.*",
            re.DOTALL,
        ),
    ):
        load_filters(filter_str)


def test_owner_filter() -> None:
    """Test basic owner filter functionality."""
    filter_obj: Filter = load_filters({"owner": ["urn:li:corpuser:john"]})
    assert filter_obj == F.owner("urn:li:corpuser:john")

    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(
                    field="owners",
                    condition="EQUAL",
                    values=["urn:li:corpuser:john"],
                )
            ]
        }
    ]


def test_glossary_term_filter() -> None:
    """Test basic glossary term filter functionality."""
    filter_obj: Filter = load_filters(
        {"glossary_term": ["urn:li:glossaryTerm:data-quality"]}
    )
    assert filter_obj == F.glossary_term("urn:li:glossaryTerm:data-quality")

    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(
                    field="glossaryTerms",
                    condition="EQUAL",
                    values=["urn:li:glossaryTerm:data-quality"],
                )
            ]
        }
    ]


def test_owner_filter_mixed_types() -> None:
    """Test owner filter with both user and group URNs."""
    filter_obj: Filter = load_filters(
        {"owner": ["urn:li:corpuser:john", "urn:li:corpGroup:engineering"]}
    )
    assert filter_obj == F.owner(
        ["urn:li:corpuser:john", "urn:li:corpGroup:engineering"]
    )
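

# A hedged sketch extrapolating from test_owner_filter: mixed user and group
# owners should compile into a single "owners" rule carrying both URNs.
def test_owner_filter_mixed_types_compile() -> None:
    filter_obj = F.owner(["urn:li:corpuser:john", "urn:li:corpGroup:engineering"])
    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(
                    field="owners",
                    condition="EQUAL",
                    values=["urn:li:corpuser:john", "urn:li:corpGroup:engineering"],
                )
            ]
        }
    ]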


def test_invalid_owner_filter() -> None:
    """Test validation error for invalid owner URN."""
    with pytest.raises(
        ValidationError, match="Owner must be a valid User or Group URN"
    ):
        F.owner("invalid-owner")


def test_glossary_term_filter_multiple() -> None:
    """Test glossary term filter with multiple terms."""
    filter_obj: Filter = load_filters(
        {
            "glossary_term": [
                "urn:li:glossaryTerm:data-quality",
                "urn:li:glossaryTerm:compliance",
            ]
        }
    )
    assert filter_obj == F.glossary_term(
        ["urn:li:glossaryTerm:data-quality", "urn:li:glossaryTerm:compliance"]
    )


def test_invalid_glossary_term_filter() -> None:
    """Test validation error for invalid glossary term URN."""
    with pytest.raises(
        ValidationError, match="Glossary term must be a valid glossary term URN"
    ):
        F.glossary_term("urn:li:corpuser:john")


def test_tag_filter() -> None:
    """Test basic tag filter functionality."""
    filter_obj: Filter = load_filters({"tag": ["urn:li:tag:data-quality"]})
    assert filter_obj == F.tag("urn:li:tag:data-quality")
    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(
                    field="tags",
                    condition="EQUAL",
                    values=["urn:li:tag:data-quality"],
                )
            ]
        }
    ]


def test_tag_filter_multiple() -> None:
    """Test tag filter with multiple tags."""
    filter_obj: Filter = load_filters(
        {"tag": ["urn:li:tag:data-quality", "urn:li:tag:production"]}
    )
    assert filter_obj == F.tag(["urn:li:tag:data-quality", "urn:li:tag:production"])


def test_invalid_tag_filter() -> None:
    """Test validation error for invalid tag URN."""
    with pytest.raises(ValidationError, match="Tag must be a valid tag URN"):
        F.tag("urn:li:corpuser:john")


def test_invalid_filter() -> None:
    with pytest.raises(InvalidUrnError):
        F.domain("marketing")


def test_unsupported_not() -> None:
    # F.env("PROD") compiles into two OR'd clauses (the "origin" and "env"
    # field variants), so it cannot be wrapped in a NOT.
    env_filter = F.env("PROD")
    with pytest.raises(
        ValidationError,
        match="Cannot negate a filter with multiple OR clauses",
    ):
        F.not_(env_filter)
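

# Companion sketch: a single-clause filter can be negated; the negated flag
# should land on the compiled rule, as with F.not_(F.domain(...)) in
# test_filters_complex.
def test_supported_not() -> None:
    filter_obj = F.not_(F.platform("snowflake"))
    assert filter_obj.compile() == [
        {
            "and": [
                SearchFilterRule(
                    field="platform.keyword",
                    condition="EQUAL",
                    values=["urn:li:dataPlatform:snowflake"],
                    negated=True,
                )
            ]
        }
    ]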


# The default "not soft-deleted" status rule that compile_filters appends to
# each OR clause unless the filter already constrains the "removed" field
# (see test_compile_no_default_status).
_default_status_filter = {
    "field": "removed",
    "condition": "EQUAL",
    "values": ["true"],
    "negated": True,
}


def test_compute_entity_types() -> None:
    assert compute_entity_types(
        [
            {
                "and": [
                    SearchFilterRule(
                        field="_entityType",
                        condition="EQUAL",
                        values=["DATASET"],
                    )
                ]
            },
            {
                "and": [
                    SearchFilterRule(
                        field="_entityType",
                        condition="EQUAL",
                        values=["CHART"],
                    )
                ]
            },
        ]
    ) == ["DATASET", "CHART"]


def test_compute_entity_types_deduplication() -> None:
    types, _ = compile_filters(
        load_filters(
            {
                "and": [
                    {"entity_type": ["DATASET"]},
                    {"entity_type": ["DATASET"]},
                    {"entity_subtype": "Table"},
                    {"not": {"platform": ["snowflake"]}},
                ]
            }
        )
    )
    assert types == ["DATASET"]


def test_compile_filters() -> None:
    filter_obj = F.and_(F.env("PROD"), F.platform("snowflake"))
    expected_filters = [
        {
            "and": [
                {
                    "field": "origin",
                    "condition": "EQUAL",
                    "values": ["PROD"],
                },
                {
                    "field": "platform.keyword",
                    "condition": "EQUAL",
                    "values": ["urn:li:dataPlatform:snowflake"],
                },
                _default_status_filter,
            ]
        },
        {
            "and": [
                {
                    "field": "env",
                    "condition": "EQUAL",
                    "values": ["PROD"],
                },
                {
                    "field": "platform.keyword",
                    "condition": "EQUAL",
                    "values": ["urn:li:dataPlatform:snowflake"],
                },
                _default_status_filter,
            ]
        },
    ]
    types, compiled = compile_filters(filter_obj)
    assert types is None
    assert compiled == expected_filters


def test_compile_no_default_status() -> None:
    filter_obj = F.and_(
        F.platform("snowflake"), F.soft_deleted(RemovedStatusFilter.ONLY_SOFT_DELETED)
    )

    _, compiled = compile_filters(filter_obj)

    # Check that no status filter was added.
    assert compiled == [
        {
            "and": [
                {
                    "condition": "EQUAL",
                    "field": "platform.keyword",
                    "values": ["urn:li:dataPlatform:snowflake"],
                },
                {
                    "condition": "EQUAL",
                    "field": "removed",
                    "values": ["true"],
                },
            ],
        },
    ]


def test_generate_filters() -> None:
    types, compiled = compile_filters(
        F.and_(
            F.entity_type(QueryUrn.ENTITY_TYPE),
            F.custom_filter("origin", "EQUAL", [DataPlatformUrn("snowflake").urn()]),
        )
    )
    assert types == ["QUERY"]
    assert compiled == [
        {
            "and": [
                {"field": "_entityType", "condition": "EQUAL", "values": ["QUERY"]},
                {
                    "field": "origin",
                    "condition": "EQUAL",
                    "values": ["urn:li:dataPlatform:snowflake"],
                },
                _default_status_filter,
            ]
        }
    ]

    assert generate_filter(
        platform=None,
        platform_instance=None,
        env=None,
        container=None,
        status=RemovedStatusFilter.NOT_SOFT_DELETED,
        extra_filters=None,
        extra_or_filters=compiled,
    ) == [
        {
            "and": [
                # This filter appears twice - once from the compiled filters, and once
                # from the status arg to generate_filter.
                _default_status_filter,
                {
                    "field": "_entityType",
                    "condition": "EQUAL",
                    "values": ["QUERY"],
                },
                {
                    "field": "origin",
                    "condition": "EQUAL",
                    "values": ["urn:li:dataPlatform:snowflake"],
                },
                _default_status_filter,
            ]
        }
    ]


def test_get_urns() -> None:
    graph = MockDataHubGraph()

    with unittest.mock.patch.object(graph, "execute_graphql") as mock_execute_graphql:
        result_urns = ["urn:li:corpuser:datahub"]
        mock_execute_graphql.return_value = {
            "scrollAcrossEntities": {
                "nextScrollId": None,
                "searchResults": [{"entity": {"urn": urn}} for urn in result_urns],
            }
        }

        client = DataHubClient(graph=graph)
        urns = client.search.get_urns(
            filter=F.and_(
                F.entity_type("corpuser"),
            )
        )
        assert list(urns) == [Urn.from_string(urn) for urn in result_urns]

        assert mock_execute_graphql.call_count == 1
        assert "scrollAcrossEntities" in mock_execute_graphql.call_args.args[0]
        mock_execute_graphql.assert_called_once_with(
            unittest.mock.ANY,
            variables={
                "types": ["CORP_USER"],
                "query": "*",
                "orFilters": [
                    {
                        "and": [
                            {
                                "field": "_entityType",
                                "condition": "EQUAL",
                                "values": ["CORP_USER"],
                            },
                            {
                                "field": "removed",
                                "condition": "EQUAL",
                                "values": ["true"],
                                "negated": True,
                            },
                        ]
                    }
                ],
                "batchSize": unittest.mock.ANY,
                "scrollId": None,
                "skipCache": False,
                "includeSoftDeleted": None,
            },
        )


def test_get_urns_with_skip_cache() -> None:
    graph = MockDataHubGraph()

    with unittest.mock.patch.object(graph, "execute_graphql") as mock_execute_graphql:
        result_urns = ["urn:li:corpuser:datahub"]
        mock_execute_graphql.return_value = {
            "scrollAcrossEntities": {
                "nextScrollId": None,
                "searchResults": [{"entity": {"urn": urn}} for urn in result_urns],
            }
        }

        client = DataHubClient(graph=graph)
        urns = client.search.get_urns(
            filter=F.and_(
                F.entity_type("corpuser"),
            ),
            skip_cache=True,
        )
        assert list(urns) == [Urn.from_string(urn) for urn in result_urns]

        assert mock_execute_graphql.call_count == 1
        mock_execute_graphql.assert_called_once_with(
            unittest.mock.ANY,
            variables={
                "types": ["CORP_USER"],
                "query": "*",
                "orFilters": [
                    {
                        "and": [
                            {
                                "field": "_entityType",
                                "condition": "EQUAL",
                                "values": ["CORP_USER"],
                            },
                            {
                                "field": "removed",
                                "condition": "EQUAL",
                                "values": ["true"],
                                "negated": True,
                            },
                        ]
                    }
                ],
                "batchSize": unittest.mock.ANY,
                "scrollId": None,
                "skipCache": True,
                "includeSoftDeleted": None,
            },
        )
